from utils.logger import Logger
import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
)
from llm.model.local_llm import LocalLLM
from llm.model.api_llm import ApiLLM
from llm.llm_const import get_local_model_name

import threading

class Auxiliary:
    """Lazily create and cache the models used for generation and embedding.

    Four model slots are managed: a local generate model, a local embedding
    model, an API generate model and an API embedding model. Each slot is
    created on first use and guarded by its own lock so that concurrent
    callers do not build the same model twice.
    """

    def __init__(self, logger: "Logger", config):
        """Read model settings from *config* and prepare empty model slots.

        Args:
            logger: Logger used for warnings and progress messages.
            config: Mapping-like object providing ``get(key, default)``. The
                keys read are ``local_model_save_path``,
                ``local_model_cache_path``, ``local_generate_model_id``,
                ``local_embedding_model_id``, ``api_generate_model_id``,
                ``generate_api_key``, ``api_embedding_model_id`` and
                ``embedding_api_key``; each defaults to ``""``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger = logger

        # Model slots; filled lazily by the get_* methods.
        self.local_generate_model = None
        self.local_embedding_model = None
        self.api_generate_model = None
        self.api_embedding_model = None

        self.local_model_save_path = config.get("local_model_save_path", "")
        self.local_model_cache_path = config.get("local_model_cache_path", "")
        self.local_generate_model_id = config.get("local_generate_model_id", "")
        self.local_embedding_model_id = config.get("local_embedding_model_id", "")

        self.api_generate_model_id = config.get("api_generate_model_id", "")
        self.generate_api_key = config.get("generate_api_key", "")
        self.api_embedding_model_id = config.get("api_embedding_model_id", "")
        self.embedding_api_key = config.get("embedding_api_key", "")

        # One lock per slot so unrelated models can be initialised in parallel.
        self.local_generate_lock = threading.Lock()
        self.local_embedding_lock = threading.Lock()
        self.api_generate_lock = threading.Lock()
        self.api_embedding_lock = threading.Lock()

    def set_local_genereate_model(self, local_llm):
        """Install *local_llm* as the local generate model if none is set yet.

        An already installed model is never replaced.
        """
        if self.local_generate_model is None:
            with self.local_generate_lock:
                # Re-check under the lock: another thread may have set it.
                if self.local_generate_model is None:
                    self.local_generate_model = local_llm

    def get_local_genereate_model(self) -> "LocalLLM | None":
        """Return the local generate model, creating it on first use.

        Returns:
            The cached ``LocalLLM``, or ``None`` (after logging a warning)
            when no local generate model id is configured.
        """
        if self.local_generate_model is None:
            model_id = self.local_generate_model_id
            if not model_id:
                self.logger.warning("Local generate model id is None")
                return None

            with self.local_generate_lock:
                if self.local_generate_model is None:
                    self.local_generate_model = self._init_local_generate_model(
                        model_id,
                        self.local_model_save_path,
                        self.local_model_cache_path,
                    )
        return self.local_generate_model

    def get_local_embedding_model(self):
        """Return the local embedding ``(model, tokenizer)`` pair.

        The pair is loaded on first use and cached afterwards.

        Returns:
            A ``(model, tokenizer)`` tuple, or ``None`` (after logging a
            warning) when no local embedding model id is configured or the
            model name cannot be resolved. A failed load is not cached.
        """
        if self.local_embedding_model is None:
            model_id = self.local_embedding_model_id
            if not model_id:
                self.logger.warning("Local embedding model id is None")
                return None

            with self.local_embedding_lock:
                if self.local_embedding_model is None:
                    self.local_embedding_model = self._init_local_embedding_model(
                        model_id,
                        self.local_model_save_path,
                        self.local_model_cache_path,
                    )
        return self.local_embedding_model

    def get_api_generate_model(self) -> "ApiLLM":
        """Return the API generate model, creating it on first use."""
        if self.api_generate_model is None:
            with self.api_generate_lock:
                if self.api_generate_model is None:
                    self.api_generate_model = self._init_api_generate_model(
                        self.api_generate_model_id, self.generate_api_key, self.logger
                    )
        return self.api_generate_model

    def get_api_embedding_model(self) -> "ApiLLM":
        """Return the API embedding model, creating it on first use."""
        if self.api_embedding_model is None:
            with self.api_embedding_lock:
                if self.api_embedding_model is None:
                    # The same ApiLLM wrapper is used for both API roles; only
                    # the model id and API key differ.
                    self.api_embedding_model = self._init_api_generate_model(
                        self.api_embedding_model_id, self.embedding_api_key, self.logger
                    )
        return self.api_embedding_model

    def _init_local_embedding_model(
        self, model_id, model_save_path, model_cache_path
    ):
        """Load the embedding model and tokenizer, downloading them if needed.

        A copy saved under ``model_save_path/<model name>`` is preferred.
        Otherwise the model is fetched by *model_id* (using *model_cache_path*
        as the download cache) and then saved to that directory.

        Returns:
            ``(model, tokenizer)``, or ``None`` when the model name cannot be
            resolved from *model_id*.
        """
        model_name = get_local_model_name(model_id)
        if not model_name:
            self.logger.warning("Cannot find local model")
            return None
        model_dir = os.path.join(model_save_path, model_name)
        self.logger.info(f"Loading embedding model from {model_dir}")

        # An empty directory can be left behind by an interrupted first
        # download; treat it the same as a missing one.
        has_saved_copy = os.path.isdir(model_dir) and bool(os.listdir(model_dir))
        if has_saved_copy:
            model = AutoModel.from_pretrained(
                model_dir, torch_dtype="auto", device_map="auto"
            )
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            return (model, tokenizer)

        model = AutoModel.from_pretrained(
            model_id,
            cache_dir=model_cache_path,
            torch_dtype="auto",
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, cache_dir=model_cache_path
        )

        # Create the target directory only after the download succeeded so a
        # failed download does not leave an empty directory behind.
        os.makedirs(model_dir, exist_ok=True)
        model.save_pretrained(model_dir)
        tokenizer.save_pretrained(model_dir)
        return (model, tokenizer)

    def _init_local_generate_model(
        self, model_id, model_save_path, model_cache_path
    ):
        """Construct the ``LocalLLM`` wrapper for *model_id*."""
        return LocalLLM(model_id, model_save_path, self.logger, model_cache_path)

    def _init_api_generate_model(self, model_id, api_key, logger):
        """Construct an ``ApiLLM`` wrapper for *model_id* using *api_key*."""
        return ApiLLM(model_id, api_key, logger)

    def local_embedding(self, text):
        """Embed *text* with the local embedding model using mean pooling.

        The last hidden state is averaged over the sequence dimension. When
        the tokenizer provides an attention mask, padded positions are
        excluded from the average.

        Args:
            text: Input passed to the tokenizer.

        Returns:
            Tensor with the sequence dimension (dim 1) of the last hidden
            state reduced by the mean.

        Raises:
            RuntimeError: If the local embedding model is not available.
        """
        loaded = self.get_local_embedding_model()
        if loaded is None:
            raise RuntimeError("Local embedding model is not available")
        model, tokenizer = loaded

        # The model is placed by device_map="auto", so follow its device
        # rather than assuming it matches self.device.
        target_device = getattr(model, "device", self.device)
        inputs = tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(target_device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        last_hidden = outputs.hidden_states[-1]
        if "attention_mask" not in inputs:
            return torch.mean(last_hidden, dim=1)

        # Masked mean in float32: avoids overflow when the model runs in half
        # precision and keeps padding tokens out of the average.
        mask = (
            inputs["attention_mask"]
            .to(device=last_hidden.device, dtype=torch.float32)
            .unsqueeze(-1)
        )
        summed = (last_hidden.to(torch.float32) * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        return (summed / token_counts).to(last_hidden.dtype)

    def api_embedding(self, text):
        """Embed *text* via the API embedding model."""
        return self.get_api_embedding_model().embedding(text)

    def embedding_similarity(self, emb1, emb2):
        """Return the cosine similarity of two embeddings.

        The similarity is taken along the last dimension, so both 1-D vectors
        and batched ``(batch, hidden)`` tensors yield one value per embedding.
        """
        return torch.nn.functional.cosine_similarity(emb1, emb2, dim=-1)

    def local_perplexity(self, text):
        """Return the perplexity of *text* from the local generate model.

        Raises:
            RuntimeError: If the local generate model is not available.
        """
        model = self.get_local_genereate_model()
        if model is None:
            raise RuntimeError("Local generate model is not available")
        return model.perplexity(text)

    def api_perplexity(self, text):
        """Return the perplexity of *text* from the API generate model."""
        return self.get_api_generate_model().perplexity(text)

    # Correctly spelled aliases; the original misspelled names are kept so
    # existing callers continue to work.
    set_local_generate_model = set_local_genereate_model
    get_local_generate_model = get_local_genereate_model
